# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

"""Groundedness metric."""

import logging

from responsibleai_text.utils.genai_metrics.constants import _CITATION
from responsibleai_text.utils.genai_metrics.scripts._compute import \
    _compute_metric

# Standard-library logger for this module. It is only used below to report
# optional packages that could not be imported.
module_logger = logging.getLogger(__name__)
# NOTE(review): the level is pinned to INFO, while the import-failure
# messages below are logged with debug(), i.e. below this logger's threshold.
module_logger.setLevel(logging.INFO)

# `evaluate` provides the Metric base class, MetricInfo, the docstring
# decorator and the logging helper used further down in this module.
try:
    import evaluate
except ImportError:
    module_logger.debug(
        'Could not import evaluate, required if using a genai model')

# `datasets` provides the Features/Value types used to declare the input
# schema in Groundedness._info.
try:
    import datasets
except ImportError:
    module_logger.debug(
        'Could not import datasets, required if using a genai model')

# Logger obtained through evaluate's logging helper; this is the logger that
# Groundedness._compute hands to _compute_metric.
# NOTE(review): `evaluate` is dereferenced unconditionally here (and again in
# the class definition below), so if the guarded import above failed this
# line raises NameError instead of the module loading without the package.
logger = evaluate.logging.get_logger(__name__)

_DESCRIPTION = """The groundedness metric.
"""

_KWARGS_DESCRIPTION = """
**SOME DESCRIPTION**
"""

_TEMPLATE = """
1. 5: The ANSWER follows logically from the information contained in the \
CONTEXT.
2. 1: The ANSWER is logically false from the information contained in the \
CONTEXT.
3. an integer score between 1 and 5 and if such integer score does not \
exists, use 1: It is not possible to determine whether the ANSWER is true or \
false without further information.
Read the passage of information thoroughly and select the correct answer from \
the three answer labels. Read the CONTEXT thoroughly to ensure you know what \
the CONTEXT entails.
Note the ANSWER is generated by a computer system, it can contain certain \
symbols, which should not be a negative factor in the evaluation.

%s

CONTEXT:
{context}

ANSWER:
{prediction}
""".strip()


@evaluate.utils.file_utils.add_start_docstrings(
    _DESCRIPTION, _KWARGS_DESCRIPTION)
class Groundedness(evaluate.Metric):
    # Metric that grades an answer against a context by sending the
    # _TEMPLATE prompt through a caller-supplied wrapper model.
    # No literal class docstring is written here: the decorator above is
    # handed the description strings for that purpose.

    def _info(self):
        """Describe the metric and the inputs it accepts.

        :return: Metric metadata carrying the description, the citation
            and a feature schema in which both "predictions" and
            "references" are declared as string values.
        :rtype: evaluate.MetricInfo
        """
        input_schema = datasets.Features({
            "predictions": datasets.Value("string", id="sequence"),
            "references": datasets.Value("string", id="sequence"),
        })
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=input_schema)

    def _compute(self, *, predictions=None, references=None, **kwargs):
        """Delegate scoring to _compute_metric with the groundedness prompt.

        :param predictions: Answers to grade; forwarded as ``prediction``,
            the name of the ``{prediction}`` placeholder in the template.
        :param references: Contexts to grade against; forwarded as
            ``context``, the name of the ``{context}`` placeholder.
        :param kwargs: Must contain ``wrapper_model``, which is forwarded
            to _compute_metric. Other entries are ignored here.
        :raises KeyError: If ``wrapper_model`` is not supplied.
        :return: Whatever _compute_metric returns.
        """
        # Looked up by subscript so a missing model surfaces as KeyError.
        judge_model = kwargs['wrapper_model']
        return _compute_metric(
            _TEMPLATE,
            logger,
            judge_model,
            prediction=predictions,
            context=references)
